import torch

def submodular_kernel_select_gpu(
    demo_embeds: torch.Tensor,  # (n, d) on GPU: feature vectors ϕ(x₁),…,ϕ(xₙ)
    test_embed:  torch.Tensor,  # (d,)   on GPU: feature vector ϕ(z)
    k:           int,
    lambd:       float,
    beta:        float = 0.02,
    lengthscale: float = 1.0
) -> list:
    """
    Greedy submodular example selection with an RBF kernel (Section 2.2).

    Kernel: k(x, x') = exp(-||x - x'||² / (2 l²)), with l = ``lengthscale``.

    At each step every not-yet-selected demo j is scored with (Eqs 12-13)
      numerator_j = [ k(z,x_j) - k_S(z)ᵀ (K_S+βI)⁻¹ k_S(x_j) ]²
      denom_j     = β + k(x_j,x_j) - k_S(x_j)ᵀ (K_S+βI)⁻¹ k_S(x_j)
      score_j     = numerator_j / denom_j + λ·denom_j
    and the arg-max is appended to S. (K_S+βI)⁻¹ is grown incrementally with
    a block-matrix (Schur complement) update instead of being re-inverted.

    Args:
        demo_embeds: (n, d) demo feature matrix.
        test_embed:  (d,) test feature vector, on the same device and with
            the same dtype as ``demo_embeds``.
        k:           number of demos to select; clamped to ``n``.
        lambd:       weight λ of the ``denom`` (diversity) term.
        beta:        ridge regularizer β added to the Gram diagonal.
        lengthscale: RBF length-scale l.

    Returns:
        List of ``min(k, n)`` distinct indices into ``demo_embeds``, in the
        order they were selected (empty when ``k <= 0``).
    """
    device = demo_embeds.device
    n = demo_embeds.shape[0]
    # Never select more demos than exist: once every candidate is masked to
    # -inf, argmax would return index 0 again and yield duplicate indices.
    k = min(k, n)
    two_l2 = 2 * lengthscale ** 2

    # 1) Demo-demo Gram matrix K_ij = exp(-||x_i - x_j||² / (2 l²)).
    #    ||x_i - x_j||² is expanded through inner products; clamp the tiny
    #    negative values that floating-point cancellation can produce.
    XX = (demo_embeds ** 2).sum(dim=1)                                    # (n,) = ||x_i||²
    pairwise_sq = XX.unsqueeze(1) + XX.unsqueeze(0) - 2 * (demo_embeds @ demo_embeds.T)
    K = torch.exp(-pairwise_sq.clamp_min(0) / two_l2)                     # (n, n)

    # 2) k_z = [k(z, x_i)]_{i=1}^n
    cross_sq = (test_embed ** 2).sum() + XX - 2 * (demo_embeds @ test_embed)  # ||z - x_i||²
    k_z = torch.exp(-cross_sq.clamp_min(0) / two_l2)                      # (n,)

    # 3) For RBF, k(x_i, x_i) = 1 for all i. Use K's dtype so the inverse built
    #    from it can be matmul'ed with float64/half data (matmul does not
    #    promote dtypes).
    k_xx = torch.ones(n, device=device, dtype=K.dtype)                    # (n,)

    # 4) Greedy loop state
    candidate_mask = torch.ones(n, dtype=torch.bool, device=device)
    selected = []
    inv_K = None  # (K_S + βI)⁻¹ of shape (|S|, |S|) once S is non-empty

    for _ in range(k):
        if inv_K is None:
            # S is empty, so k_S = 0 in Eqs (12)-(13).
            numerator = k_z.pow(2)                                        # (n,)
            denom = beta + k_xx                                           # (n,)
        else:
            # Column j of kSx is k_S(x_j) = [k(x_s, x_j)]_{s∈S}.
            kSx = K[selected]                                             # (m, n)
            u_z = inv_K @ k_z[selected]                                   # (m,) = (K_S+βI)⁻¹ k_S(z)

            # numerator_j = ( k(z,x_j) - k_S(z)ᵀ (K_S+βI)⁻¹ k_S(x_j) )²  (eq 12)
            numerator = (k_z - u_z @ kSx).pow(2)                          # (n,)

            # denom_j = β + k(x_j,x_j) - k_S(x_j)ᵀ (K_S+βI)⁻¹ k_S(x_j)    (eq 11 & 13)
            denom = beta + k_xx - (kSx * (inv_K @ kSx)).sum(dim=0)        # (n,)

        scores = numerator / denom + lambd * denom                        # (n,)
        scores = scores.masked_fill(~candidate_mask, float('-inf'))

        best = torch.argmax(scores).item()
        selected.append(best)
        candidate_mask[best] = False

        # Update inv_K = (K_S + βI)⁻¹ for S ∪ {best}. reshape() keeps dtype and
        # device, unlike torch.tensor([[...]]), which also forces a host sync.
        c = k_xx[best] + beta                                             # 0-d tensor
        if inv_K is None:
            inv_K = (1.0 / c).reshape(1, 1)
        else:
            # With A = K_{old S} + βI, u = [k(x_s, x_best)]_{s∈old S} and
            # Schur complement s = c - uᵀA⁻¹u:
            # inv([[A, u], [uᵀ, c]]) =
            #   [[A⁻¹ + A⁻¹u uᵀA⁻¹ / s,  -A⁻¹u / s],
            #    [-(A⁻¹u)ᵀ / s,           1 / s   ]]
            u = K[selected[:-1], best]                                    # (m,)
            A_inv_u = inv_K @ u                                           # (m,)
            schur = c - u @ A_inv_u                                       # 0-d tensor
            top_left = inv_K + torch.outer(A_inv_u, A_inv_u) / schur      # (m, m)
            top_right = (-A_inv_u / schur).unsqueeze(1)                   # (m, 1)
            inv_K = torch.cat([
                torch.cat([top_left, top_right], dim=1),
                torch.cat([top_right.T, (1.0 / schur).reshape(1, 1)], dim=1),
            ], dim=0)                                                     # (m+1, m+1)

    return selected



def submodular_poly_kernel_select_gpu(
    demo_embeds: torch.Tensor,    # (n, d) on GPU: ϕ(x₁),…,ϕ(xₙ)
    test_embed:  torch.Tensor,    # (d,)   on GPU: ϕ(z)
    k:           int,             # how many to select
    lambd:       float,           # diversity trade-off λ
    beta:        float = 0.02,    # regularization β
    degree:      int = 3,         # polynomial degree p
    alpha:       float = 1.0,     # scaling α
    coef0:       float = 1.0      # offset c₀
) -> list:
    """
    Greedy submodular selection with a polynomial kernel (Sec 2.2 general kernelization).

    Kernel: k(x,x') = (α xᵀx' + c₀)^p

    Uses Eqs (12)-(13) for the score of each not-yet-selected demo j:
      numerator_j   = [ k(z,x_j) - k_S(z)ᵀ (K_S+βI)⁻¹ k_S(x_j) ]²
      denom_j       = β + k(x_j,x_j) - k_S(x_j)ᵀ (K_S+βI)⁻¹ k_S(x_j)
      score_j       = numerator_j/denom_j + λ·denom_j
    The arg-max is appended to S. (K_S+βI)⁻¹ is grown incrementally with a
    block-matrix (Schur complement) update rather than re-inverted.

    Returns:
        List of ``min(k, n)`` distinct indices into ``demo_embeds``, in the
        order they were selected (empty when ``k <= 0``).
    """
    device = demo_embeds.device
    n = demo_embeds.shape[0]
    # Never select more demos than exist: once every candidate is masked to
    # -inf, argmax would return index 0 again and yield duplicate indices.
    k = min(k, n)

    # 1) Full Gram matrix K_ij = (α ⟨x_i,x_j⟩ + c₀)^p
    K = (alpha * (demo_embeds @ demo_embeds.T) + coef0).pow(degree)      # (n, n)

    # 2) k_z = [k(z, x_i)] = (α ⟨z,x_i⟩ + c₀)^p
    k_z = (alpha * (demo_embeds @ test_embed) + coef0).pow(degree)       # (n,)

    # 3) Diagonal entries k(x_i,x_i) = (α ||x_i||² + c₀)^p
    xx_norm2 = (demo_embeds ** 2).sum(dim=1)                             # (n,)
    k_xx = (alpha * xx_norm2 + coef0).pow(degree)                        # (n,)

    # 4) Greedy loop state
    candidate_mask = torch.ones(n, dtype=torch.bool, device=device)
    selected = []
    inv_K = None  # (K_S + βI)⁻¹ of shape (|S|, |S|) once S is non-empty

    for _ in range(k):
        if inv_K is None:
            # S is empty, so k_S = 0 in Eqs (12)-(13).
            numerator = k_z.pow(2)                                       # (n,)
            denom = beta + k_xx                                          # (n,)
        else:
            # Column j of kSx is k_S(x_j) = K[S, j].
            kSx = K[selected]                                            # (m, n)
            u_z = inv_K @ k_z[selected]                                  # (m,) = (K_S+βI)⁻¹ k_S(z)

            # numerator_j = [k(z,x_j) − u_zᵀ k_S(x_j)]²                 (Eq 12)
            numerator = (k_z - u_z @ kSx).pow(2)                         # (n,)

            # denom_j = β + k(x_j,x_j) − k_S(x_j)ᵀ (K_S+βI)⁻¹ k_S(x_j)  (Eq 13)
            denom = beta + k_xx - (kSx * (inv_K @ kSx)).sum(dim=0)       # (n,)

        scores = numerator / denom + lambd * denom                       # (n,)
        scores = scores.masked_fill(~candidate_mask, float('-inf'))

        best = torch.argmax(scores).item()
        selected.append(best)
        candidate_mask[best] = False

        # Update inv_K = (K_S + βI)⁻¹ for S ∪ {best}. reshape() keeps dtype and
        # device, unlike torch.tensor([[...]]), which also forces a host sync.
        c = k_xx[best] + beta                                            # 0-d tensor
        if inv_K is None:
            inv_K = (1.0 / c).reshape(1, 1)
        else:
            # With A = K_{old S} + βI, u = K[old S, best] and Schur complement
            # s = c - uᵀA⁻¹u:
            # inv([[A, u], [uᵀ, c]]) =
            #   [[A⁻¹ + A⁻¹u uᵀA⁻¹ / s,  -A⁻¹u / s],
            #    [-(A⁻¹u)ᵀ / s,           1 / s   ]]
            u = K[selected[:-1], best]                                   # (m,)
            A_inv_u = inv_K @ u                                          # (m,)
            schur = c - u @ A_inv_u                                      # 0-d tensor
            top_left = inv_K + torch.outer(A_inv_u, A_inv_u) / schur     # (m, m)
            top_right = (-A_inv_u / schur).unsqueeze(1)                  # (m, 1)
            inv_K = torch.cat([
                torch.cat([top_left, top_right], dim=1),
                torch.cat([top_right.T, (1.0 / schur).reshape(1, 1)], dim=1),
            ], dim=0)                                                    # (m+1, m+1)

    return selected
